{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/NeuromatchAcademy/course-content/blob/master/tutorials/W2D1_DeepLearning/W2D1_Tutorial3.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neuromatch Academy: Week 2, Day 1, Tutorial 3\n",
    "# Deep Learning: Building and Evaluating Normative Encoding Models\n",
    "\n",
    "**Content creators**: Jorge A. Menendez, Yalda Mohsenzadeh, Carsen Stringer\n",
    "\n",
    "**Content reviewers**: Roozbeh Farhoodi, Madineh Sarvestani, Kshitij Dwivedi, Spiros Chavlis, Ella Batty, Michael Waskom\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Tutorial Objectives\n",
    "\n",
    "\n",
    "In this tutorial, we'll be using deep learning to build an encoding model of the visual system, and then compare its internal representations to those observed in neural data.\n",
    "\n",
    "Importantly, the encoding model we'll use here is different from the encoding models used in Tutorial 2. Its parameters won't be directly optimized to fit the neural data. Instead, we will optimize its parameters to solve a particular visual task that we know the brain can solve. We therefore refer to it as a \"normative\" encoding model, since it is optimized for a specific behavioral task.\n",
    "\n",
    "To then evaluate whether this normative encoding model is actually a good model of the brain, we'll analyze its internal representations and compare them to the representations observed in mouse primary visual cortex. Since we understand exactly what the encoding model's representations are optimized to do, any similarities will hopefully shed light on why the representations in the brain look the way they do.\n",
    "\n",
    "More concretely, our goal will be to learn how to:\n",
    "* Visualize and analyze the internal representations of a deep network\n",
    "* Quantify the similarity between distributed representations in a model and neural representations observed in recordings, using Representational Similarity Analysis (RSA)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Setup\n",
    "\n",
    "**Don't forget to execute the hidden cells below!**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "both",
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:02.122267Z",
     "iopub.status.busy": "2021-05-31T20:50:02.121722Z",
     "iopub.status.idle": "2021-05-31T20:50:03.111491Z",
     "shell.execute_reply": "2021-05-31T20:50:03.110267Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import zscore\n",
    "import matplotlib as mpl\n",
    "from matplotlib import pyplot as plt\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:03.120178Z",
     "iopub.status.busy": "2021-05-31T20:50:03.119670Z",
     "iopub.status.idle": "2021-05-31T20:50:03.170192Z",
     "shell.execute_reply": "2021-05-31T20:50:03.169140Z"
    }
   },
   "outputs": [],
   "source": [
    "#@title Data retrieval and loading\n",
    "import os\n",
    "import hashlib\n",
    "import requests\n",
    "\n",
    "fname = \"W3D4_stringer_oribinned1.npz\"\n",
    "url = \"https://osf.io/683xc/download\"\n",
    "expected_md5 = \"436599dfd8ebe6019f066c38aed20580\"\n",
    "\n",
    "if not os.path.isfile(fname):\n",
    "  try:\n",
    "    r = requests.get(url)\n",
    "  except requests.ConnectionError:\n",
    "    print(\"!!! Failed to download data !!!\")\n",
    "  else:\n",
    "    if r.status_code != requests.codes.ok:\n",
    "      print(\"!!! Failed to download data !!!\")\n",
    "    elif hashlib.md5(r.content).hexdigest() != expected_md5:\n",
    "      print(\"!!! Data download appears corrupted !!!\")\n",
    "    else:\n",
    "      with open(fname, \"wb\") as fid:\n",
    "        fid.write(r.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:03.176566Z",
     "iopub.status.busy": "2021-05-31T20:50:03.175570Z",
     "iopub.status.idle": "2021-05-31T20:50:03.204807Z",
     "shell.execute_reply": "2021-05-31T20:50:03.204332Z"
    }
   },
   "outputs": [],
   "source": [
    "#@title Figure Settings\n",
    "\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format='retina'\n",
    "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/nma.mplstyle\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:03.214560Z",
     "iopub.status.busy": "2021-05-31T20:50:03.213399Z",
     "iopub.status.idle": "2021-05-31T20:50:03.215212Z",
     "shell.execute_reply": "2021-05-31T20:50:03.215687Z"
    }
   },
   "outputs": [],
   "source": [
    "#@title Plotting Functions\n",
    "\n",
    "def show_stimulus(img, ax=None):\n",
    "  \"\"\"Visualize a stimulus\"\"\"\n",
    "  if ax is None:\n",
    "    ax = plt.gca()\n",
    "  ax.imshow(img, cmap=mpl.cm.binary)\n",
    "  ax.set_aspect('auto')\n",
    "  ax.set_xticks([])\n",
    "  ax.set_yticks([])\n",
    "  ax.spines['left'].set_visible(False)\n",
    "  ax.spines['bottom'].set_visible(False)\n",
    "\n",
    "def plot_corr_matrix(rdm, ax=None):\n",
    "  \"\"\"Plot dissimilarity matrix\n",
    "\n",
    "  Args:\n",
    "    rdm (numpy array): n_stimuli x n_stimuli representational dissimilarity\n",
    "      matrix\n",
    "    ax (matplotlib axes): axes onto which to plot\n",
    "\n",
    "  Returns:\n",
    "    nothing\n",
    "\n",
    "  \"\"\"\n",
    "  if ax is None:\n",
    "    ax = plt.gca()\n",
    "  image = ax.imshow(rdm, vmin=0.0, vmax=2.0)\n",
    "  ax.set_xticks([])\n",
    "  ax.set_yticks([])\n",
    "  cbar = plt.colorbar(image, ax=ax, label='dissimilarity')\n",
    "\n",
    "\n",
    "def plot_multiple_rdm(rdm_dict):\n",
    "  \"\"\"Draw multiple subplots for each RDM in rdm_dict.\"\"\"\n",
    "  fig, axs = plt.subplots(1, len(rdm_dict), squeeze=False,\n",
    "                          figsize=(4 * len(rdm_dict), 3.5))\n",
    "  axs = axs[0]  # 1D array of axes, even if rdm_dict has a single entry\n",
    "\n",
    "  # Plot each RDM in its own subplot\n",
    "  for i, (label, rdm) in enumerate(rdm_dict.items()):\n",
    "    plot_corr_matrix(rdm, axs[i])\n",
    "    axs[i].set_title(label)\n",
    "\n",
    "\n",
    "def plot_rdm_rdm_correlations(rdm_sim):\n",
    "  \"\"\"Draw a bar plot showing between-RDM correlations.\"\"\"\n",
    "  f, ax = plt.subplots()\n",
    "  ax.bar(rdm_sim.keys(), rdm_sim.values())\n",
    "  ax.set_xlabel('Deep network model layer')\n",
    "  ax.set_ylabel('Correlation of model layer RDM\\nwith mouse V1 RDM')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:03.239415Z",
     "iopub.status.busy": "2021-05-31T20:50:03.228563Z",
     "iopub.status.idle": "2021-05-31T20:50:03.241221Z",
     "shell.execute_reply": "2021-05-31T20:50:03.240833Z"
    }
   },
   "outputs": [],
   "source": [
    "#@title Helper Functions\n",
    "\n",
    "def load_data(data_name=fname, bin_width=1):\n",
    "  \"\"\"Load mouse V1 data from Stringer et al. (2019)\n",
    "\n",
    "  Data from study reported in this preprint:\n",
    "  https://www.biorxiv.org/content/10.1101/679324v2.abstract\n",
    "\n",
    "  These data comprise time-averaged responses of ~20,000 neurons\n",
    "  to ~4,000 stimulus gratings of different orientations, recorded\n",
    "  through calcium imaging. The responses have been normalized by\n",
    "  spontaneous levels of activity and then z-scored over stimuli, so\n",
    "  expect negative numbers. They have also been binned and averaged\n",
    "  to each degree of orientation.\n",
    "\n",
    "  This function returns the relevant data (neural responses and\n",
    "  stimulus orientations) in a torch.Tensor of data type torch.float32\n",
    "  in order to match the default data type for nn.Parameters in\n",
    "  Google Colab.\n",
    "\n",
    "  This function will actually average responses to stimuli with orientations\n",
    "  falling within bins specified by the bin_width argument. This helps\n",
    "  produce individual neural \"responses\" with smoother and more\n",
    "  interpretable tuning curves.\n",
    "\n",
    "  Args:\n",
    "    data_name (str): path to the .npz file containing the data\n",
    "    bin_width (float): size of stimulus bins over which to average neural\n",
    "      responses\n",
    "\n",
    "  Returns:\n",
    "    resp (torch.Tensor): n_stimuli x n_neurons matrix of neural responses,\n",
    "        each row contains the responses of each neuron to a given stimulus.\n",
    "        As mentioned above, neural \"response\" is actually an average over\n",
    "        responses to stimuli with similar angles falling within specified bins.\n",
    "    stimuli: (torch.Tensor): n_stimuli x 1 column vector with orientation\n",
    "        of each stimulus, in degrees. This is actually the mean orientation\n",
    "        of all stimuli in each bin.\n",
    "\n",
    "  \"\"\"\n",
    "  with np.load(data_name) as dobj:\n",
    "    data = dict(**dobj)\n",
    "  resp = data['resp']\n",
    "  stimuli = data['stimuli']\n",
    "\n",
    "  if bin_width > 1:\n",
    "    # Bin neural responses and stimuli\n",
    "    bins = np.digitize(stimuli, np.arange(0, 360 + bin_width, bin_width))\n",
    "    stimuli_binned = np.array([stimuli[bins == i].mean() for i in np.unique(bins)])\n",
    "    resp_binned = np.array([resp[bins == i, :].mean(0) for i in np.unique(bins)])\n",
    "  else:\n",
    "    resp_binned = resp\n",
    "    stimuli_binned = stimuli\n",
    "\n",
    "  # only use stimuli <= 180\n",
    "  resp_binned = resp_binned[stimuli_binned <= 180]\n",
    "  stimuli_binned = stimuli_binned[stimuli_binned <= 180]\n",
    "\n",
    "  stimuli_binned -= 90  # 0 means vertical, -ve means tilted left, +ve means tilted right\n",
    "\n",
    "  # Return as torch.Tensor\n",
    "  resp_tensor = torch.tensor(resp_binned, dtype=torch.float32)\n",
    "  stimuli_tensor = torch.tensor(stimuli_binned, dtype=torch.float32).unsqueeze(1)  # add singleton dimension to make a column vector\n",
    "\n",
    "  return resp_tensor, stimuli_tensor\n",
    "\n",
    "def grating(angle, sf=1 / 28, res=0.1, patch=False):\n",
    "  \"\"\"Generate oriented grating stimulus\n",
    "\n",
    "  Args:\n",
    "    angle (float): orientation of grating (angle from vertical), in degrees\n",
    "    sf (float): controls spatial frequency of the grating\n",
    "    res (float): resolution of image. Smaller values will make the image\n",
    "      smaller in terms of pixels. res=1.0 corresponds to 640 x 480 pixels.\n",
    "    patch (boolean): set to True to make the grating a localized\n",
    "      patch on the left side of the image. If False, then the\n",
    "      grating occupies the full image.\n",
    "\n",
    "  Returns:\n",
    "    torch.Tensor: (res * 480) x (res * 640) pixel oriented grating image\n",
    "\n",
    "  \"\"\"\n",
    "\n",
    "  angle = np.deg2rad(angle)  # transform to radians\n",
    "\n",
    "  wpix, hpix = 640, 480  # width and height of image in pixels for res=1.0\n",
    "\n",
    "  xx, yy = np.meshgrid(sf * np.arange(0, wpix * res) / res, sf * np.arange(0, hpix * res) / res)\n",
    "\n",
    "  if patch:\n",
    "    gratings = np.cos(xx * np.cos(angle + .1) + yy * np.sin(angle + .1))  # phase shift to make it better fit within patch\n",
    "    gratings[gratings < 0] = 0\n",
    "    gratings[gratings > 0] = 1\n",
    "    xcent = gratings.shape[1] * .75\n",
    "    ycent = gratings.shape[0] / 2\n",
    "    xxc, yyc = np.meshgrid(np.arange(0, gratings.shape[1]), np.arange(0, gratings.shape[0]))\n",
    "    icirc = ((xxc - xcent) ** 2 + (yyc - ycent) ** 2) ** 0.5 < wpix / 3 / 2 * res\n",
    "    gratings[~icirc] = 0.5\n",
    "\n",
    "  else:\n",
    "    gratings = np.cos(xx * np.cos(angle) + yy * np.sin(angle))\n",
    "    gratings[gratings < 0] = 0\n",
    "    gratings[gratings > 0] = 1\n",
    "\n",
    "  # Return torch tensor\n",
    "  return torch.tensor(gratings, dtype=torch.float32)\n",
    "\n",
    "\n",
    "class CNN(nn.Module):\n",
    "  \"\"\"Deep convolutional network with one convolutional + pooling layer followed\n",
    "  by one fully connected layer\n",
    "\n",
    "  Args:\n",
    "    h_in (int): height of input image, in pixels (i.e. number of rows)\n",
    "    w_in (int): width of input image, in pixels (i.e. number of columns)\n",
    "\n",
    "  Attributes:\n",
    "    conv (nn.Conv2d): filter weights of convolutional layer\n",
    "    pool (nn.MaxPool2d): max pooling layer\n",
    "    dims (tuple of ints): dimensions of output from pool layer\n",
    "    fc (nn.Linear): weights and biases of fully connected layer\n",
    "    out (nn.Linear): weights and biases of output layer\n",
    "\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self, h_in, w_in):\n",
    "    super().__init__()\n",
    "    C_in = 1  # input stimuli have only 1 input channel\n",
    "    C_out = 8  # number of output channels (i.e. of convolutional kernels to convolve the input with)\n",
    "    K = 5  # size of each convolutional kernel\n",
    "    Kpool = 2  # size of patches over which to pool\n",
    "    self.conv = nn.Conv2d(C_in, C_out, kernel_size=K, padding=K//2)  # add padding to ensure that each channel has same dimensionality as input\n",
    "    self.pool = nn.MaxPool2d(Kpool)\n",
    "    self.dims = (C_out, h_in // Kpool, w_in // Kpool)  # dimensions of pool layer output\n",
    "    self.fc = nn.Linear(np.prod(self.dims), 10)  # flattened pool output --> 10D representation\n",
    "    self.out = nn.Linear(10, 1)  # 10D representation --> scalar\n",
    "\n",
    "  def forward(self, x):\n",
    "    \"\"\"Classify grating stimulus as tilted right or left\n",
    "\n",
    "    Args:\n",
    "      x (torch.Tensor): p x 48 x 64 tensor with pixel grayscale values for\n",
    "          each of p stimulus images.\n",
    "\n",
    "    Returns:\n",
    "      torch.Tensor: p x 1 tensor with network outputs for each input provided\n",
    "          in x. Each output should be interpreted as the probability of the\n",
    "          corresponding stimulus being tilted right.\n",
    "\n",
    "    \"\"\"\n",
    "    x = x.unsqueeze(1)  # p x 1 x 48 x 64, add a singleton dimension for the single stimulus channel\n",
    "    x = torch.relu(self.conv(x))  # output of convolutional layer\n",
    "    x = self.pool(x)  # output of pooling layer\n",
    "    x = x.view(-1, np.prod(self.dims))  # flatten pooling layer outputs into a vector\n",
    "    x = torch.relu(self.fc(x))  # output of fully connected layer\n",
    "    x = torch.sigmoid(self.out(x))  # network output\n",
    "    return x\n",
    "\n",
    "\n",
    "def train(net, train_data, train_labels, n_epochs=20, batch_size=100, learning_rate=1e-3, momentum=.99):\n",
    "  \"\"\"Run stochastic gradient descent on binary cross-entropy loss for a given\n",
    "  deep network (cf. appendix for details)\n",
    "\n",
    "  Args:\n",
    "    net (nn.Module): deep network whose parameters to optimize with SGD\n",
    "    train_data (torch.Tensor): n_train x h x w tensor with stimulus gratings\n",
    "    train_labels (torch.Tensor): n_train x 1 tensor with true tilt of each\n",
    "      stimulus grating in train_data, i.e. 1. for right, 0. for left\n",
    "    n_epochs (int): number of times to run SGD through whole training data set\n",
    "    batch_size (int): number of training data samples in each mini-batch\n",
    "    learning_rate (float): learning rate to use for SGD updates\n",
    "    momentum (float): momentum parameter for SGD updates\n",
    "\n",
    "  \"\"\"\n",
    "\n",
    "  # Initialize binary cross-entropy loss function\n",
    "  loss_fn = nn.BCELoss()\n",
    "\n",
    "  # Initialize SGD optimizer with momentum\n",
    "  optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)\n",
    "\n",
    "  # Placeholder to save loss at each iteration\n",
    "  track_loss = []\n",
    "\n",
    "  # Loop over epochs\n",
    "  for i in range(n_epochs):\n",
    "\n",
    "    # Split up training data into random non-overlapping mini-batches\n",
    "    ishuffle = torch.randperm(train_data.shape[0])  # random ordering of training data\n",
    "    minibatch_data = torch.split(train_data[ishuffle], batch_size)  # split train_data into minibatches\n",
    "    minibatch_labels = torch.split(train_labels[ishuffle], batch_size)  # split train_labels into minibatches\n",
    "\n",
    "    # Loop over mini-batches\n",
    "    for stimuli, tilt in zip(minibatch_data, minibatch_labels):\n",
    "\n",
    "      # Evaluate loss and update network weights\n",
    "      out = net(stimuli)  # predicted probability of tilt right\n",
    "      loss = loss_fn(out, tilt)  # evaluate loss\n",
    "      optimizer.zero_grad()  # clear gradients\n",
    "      loss.backward()  # compute gradients\n",
    "      optimizer.step()  # update weights\n",
    "\n",
    "      # Keep track of loss at each iteration\n",
    "      track_loss.append(loss.item())\n",
    "\n",
    "    # Track progress\n",
    "    if (i + 1) % max(1, n_epochs // 5) == 0:  # max() avoids division by zero when n_epochs < 5\n",
    "      print(f'epoch {i + 1} | loss on last mini-batch: {loss.item(): .2e}')\n",
    "\n",
    "  print('training done!')\n",
    "\n",
    "\n",
    "def get_hidden_activity(net, stimuli, layer_labels):\n",
    "  \"\"\"Retrieve internal representations of network\n",
    "\n",
    "  Args:\n",
    "    net (nn.Module): deep network\n",
    "    stimuli (torch.Tensor): p x 48 x 64 tensor with stimuli for which to\n",
    "      compute and retrieve internal representations\n",
    "    layer_labels (list): list of strings with labels of each layer for which\n",
    "      to return its internal representations\n",
    "\n",
    "  Returns:\n",
    "    dict: internal representations at each layer of the network, in\n",
    "      numpy arrays. The keys of this dict are the strings in layer_labels.\n",
    "\n",
    "  \"\"\"\n",
    "\n",
    "  # Placeholder\n",
    "  hidden_activity = {}\n",
    "\n",
    "  # Attach 'hooks' to each layer of the network to store hidden\n",
    "  # representations in hidden_activity\n",
    "  def hook(module, input, output):\n",
    "    module_label = list(net._modules.keys())[np.argwhere([module == m for m in net._modules.values()])[0, 0]]\n",
    "    if module_label in layer_labels:  # ignore output layer\n",
    "      hidden_activity[module_label] = output.view(stimuli.shape[0], -1).detach().numpy()\n",
    "  hooks = [layer.register_forward_hook(hook) for layer in net.children()]\n",
    "\n",
    "  # Run stimuli through the network\n",
    "  pred = net(stimuli)\n",
    "\n",
    "  # Remove the hooks\n",
    "  for h in hooks:\n",
    "    h.remove()\n",
    "\n",
    "  return hidden_activity"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Section 1: Setting up deep network and neural data\n",
    "\n",
    "In the following sections, we will compare the activity in a deep network, specifically a CNN, with neural activity. First, we need to understand the task we are using (Section 1.1), train our deep network (Section 1.2), and load the neural data (Section 1.3)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:03.248507Z",
     "iopub.status.busy": "2021-05-31T20:50:03.247985Z",
     "iopub.status.idle": "2021-05-31T20:50:03.270942Z",
     "shell.execute_reply": "2021-05-31T20:50:03.271359Z"
    }
   },
   "outputs": [],
   "source": [
    "#@title Video 1: Deep convolutional network for orientation discrimination\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"KlXtKJCpV4I\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtube.com/watch?v=\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Section 1.1: Orientation discrimination task\n",
    "\n",
    "We will build our normative encoding model by optimizing its parameters to solve an orientation discrimination task. \n",
    "\n",
    "The task is to tell whether a given grating stimulus is tilted to the \"right\" or \"left\"; that is, whether its angle relative to the vertical is positive or negative, respectively. We show example stimuli below, which were constructed using the helper function `grating()`.\n",
    "\n",
    "Note that many mammalian visual systems are known to be capable of solving this task. It is therefore conceivable that the representations in a deep network model optimized for this task might resemble those in the brain. To test this hypothesis, we will compare the representations of our optimized encoding model to neural activity recorded in response to these very same stimuli, courtesy of [Stringer et al. 2019](https://www.biorxiv.org/content/10.1101/679324v2.abstract)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:03.293010Z",
     "iopub.status.busy": "2021-05-31T20:50:03.287540Z",
     "iopub.status.idle": "2021-05-31T20:50:03.806035Z",
     "shell.execute_reply": "2021-05-31T20:50:03.806485Z"
    }
   },
   "outputs": [],
   "source": [
    "#@title\n",
    "#@markdown Execute this cell to plot example stimuli\n",
    "\n",
    "orientations = np.linspace(-90, 90, 5)\n",
    "\n",
    "h = 3\n",
    "n_col = len(orientations)\n",
    "fig, axs = plt.subplots(1, n_col, figsize=(h * n_col, h))\n",
    "\n",
    "h, w = grating(0).shape  # height and width of stimulus\n",
    "print('stimulus size: %i x %i' % (h, w))\n",
    "\n",
    "for i, ori in enumerate(orientations):\n",
    "  stimulus = grating(ori)\n",
    "  axs[i].set_title(f'{ori: .0f}$^o$')\n",
    "  show_stimulus(stimulus, axs[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Section 1.2: A deep network model of orientation discrimination\n",
    "\n",
    "Our goal is to build a model that solves the orientation discrimination task outlined above. The model should take as input a stimulus image and output the probability of that stimulus being tilted right.\n",
    "\n",
    "To do this, we will use a **convolutional neural network (CNN)**, which is the type of network we saw in Tutorial 2. Here, we will use a CNN that performs *two-dimensional* convolutions on the raw stimulus image (which is a 2D matrix of pixels), rather than *one-dimensional* convolutions on a categorical 1D vector representation of the stimulus. CNNs are commonly used for image processing. \n",
    "\n",
    "The particular CNN we will use here has two layers:\n",
    "1. a *convolutional layer*, which convolves the images with a set of filters\n",
    "2. a *fully connected layer*, which transforms the output of this convolution into a 10-dimensional representation\n",
    "\n",
    "Finally, a set of output weights transforms this 10-dimensional representation into a single scalar $p$, denoting the predicted probability of the input stimulus being tilted right. \n",
    "\n",
    "<p align=\"center\">\n",
    "  <img src=\"https://github.com/NeuromatchAcademy/course-content/blob/master/tutorials/static/conv-network.png?raw=true\" width=\"450\" />\n",
    "</p>\n",
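    "\n",
    "To make the layer sizes concrete, here is a standalone sketch that traces the shape of a single stimulus through layers with the same hyperparameters as the `CNN` helper class. Those hyperparameters are 8 convolutional kernels of size 5 with padding 2, then 2 x 2 max pooling, a 10-unit fully connected layer, and a scalar sigmoid output. The weights are freshly initialized and untrained. The sketch assumes a 48 x 64 input, as produced by `grating()` with its default `res=0.1`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# One blank 48 x 64 stimulus, shaped (batch, channels, height, width)\n",
    "x = torch.zeros(1, 1, 48, 64)\n",
    "\n",
    "conv = nn.Conv2d(1, 8, kernel_size=5, padding=2)  # padding=2 keeps 48 x 64\n",
    "pool = nn.MaxPool2d(2)  # halves height and width\n",
    "fc = nn.Linear(8 * 24 * 32, 10)  # flattened pool output --> 10D\n",
    "out = nn.Linear(10, 1)  # 10D --> scalar\n",
    "\n",
    "h_conv = torch.relu(conv(x))\n",
    "h_pool = pool(h_conv)\n",
    "h_fc = torch.relu(fc(h_pool.flatten(start_dim=1)))\n",
    "p = torch.sigmoid(out(h_fc))  # probability of 'tilted right'\n",
    "\n",
    "print(h_conv.shape)  # torch.Size([1, 8, 48, 64])\n",
    "print(h_pool.shape)  # torch.Size([1, 8, 24, 32])\n",
    "print(h_fc.shape)    # torch.Size([1, 10])\n",
    "print(p.shape)       # torch.Size([1, 1])\n",
    "```\n",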
    "\n",
    "See Bonus Section 1 for in-depth instructions on how to code up such a network in PyTorch. For now, however, we'll leave these details aside and focus on training this network and analyzing its internal representations.\n",
    "\n",
    "Run the next cell to train such a network to solve this task. The cell first initializes our CNN model and then builds a dataset of oriented grating stimuli to train it on. These are passed into a function called `train()`, which uses SGD to optimize the model's parameters and takes arguments similar to those of the `train()` function we wrote in Tutorial 1.\n",
    "\n",
    "Note that it may take ~30 seconds for the training to complete."
   ]
  },
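  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `train()` helper minimizes the binary cross-entropy between the network output $p$ and the true tilt label $y$ (1 for right, 0 for left). For a single stimulus this loss is $-[y \\log p + (1 - y) \\log (1 - p)]$, and it is averaged over each mini-batch. As a quick check of that formula, the standalone sketch below compares a hand-written version against `nn.BCELoss`, which `train()` uses. The predictions and labels are made-up values chosen only for illustration.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# Made-up predicted probabilities of 'tilted right' and true labels, as column vectors\n",
    "p = torch.tensor([[0.9], [0.2], [0.6]])\n",
    "y = torch.tensor([[1.0], [0.0], [0.0]])\n",
    "\n",
    "# Binary cross-entropy written out by hand, averaged over the 3 samples\n",
    "manual = -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).mean()\n",
    "\n",
    "# Built-in loss used inside train() (its default reduction is the mean)\n",
    "builtin = nn.BCELoss()(p, y)\n",
    "\n",
    "print(manual.item(), builtin.item())\n",
    "```"
   ]
  },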
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:03.810764Z",
     "iopub.status.busy": "2021-05-31T20:50:03.810309Z",
     "iopub.status.idle": "2021-05-31T20:50:03.814280Z",
     "shell.execute_reply": "2021-05-31T20:50:03.813869Z"
    }
   },
   "outputs": [],
   "source": [
    "help(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:03.819092Z",
     "iopub.status.busy": "2021-05-31T20:50:03.818609Z",
     "iopub.status.idle": "2021-05-31T20:50:08.849129Z",
     "shell.execute_reply": "2021-05-31T20:50:08.848605Z"
    }
   },
   "outputs": [],
   "source": [
    "# Set random seeds for reproducibility\n",
    "np.random.seed(12)\n",
    "torch.manual_seed(12)\n",
    "\n",
    "# Initialize CNN model\n",
    "net = CNN(h, w)\n",
    "\n",
    "# Build training set to train it on\n",
    "n_train = 1000  # size of training set\n",
    "\n",
    "# sample n_train random orientations between -90 and +90 degrees\n",
    "ori = (np.random.rand(n_train) - 0.5) * 180\n",
    "\n",
    "# build oriented grating stimuli\n",
    "stimuli = torch.stack([grating(i) for i in ori])\n",
    "\n",
    "# stimulus tilt: 1. if tilted right, 0. if tilted left, as a column vector\n",
    "tilt = torch.tensor(ori > 0).type(torch.float).unsqueeze(-1)\n",
    "\n",
    "# Train model\n",
    "train(net, stimuli, tilt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Section 1.3: Load data\n",
    "\n",
    "In the next cell, we provide code for loading in some data from [this paper](https://www.biorxiv.org/content/10.1101/679324v2.abstract), which contains the responses of ~20,000 neurons in mouse primary visual cortex to grating stimuli like those used to train our network (this is the same data used in Tutorial 1). These data are stored in two variables:\n",
    "* `resp_v1` is a matrix where each row contains the responses of all neurons to a single stimulus.\n",
    "* `ori` is a vector with the orientations of each stimulus, in degrees. As in the above convention, negative angles denote stimuli tilted to the left and positive angles denote stimuli tilted to the right.\n",
    "\n",
    "We will then extract our deep CNN model's representations of these same stimuli (i.e. oriented gratings with the orientations in `ori`). We will run the same stimuli through our CNN model and use the helper function `get_hidden_activity()` to store the model's internal representations. The output of this function is a Python `dict`, which contains a matrix of population responses (just like `resp_v1`) for each layer of the network specified by the `layer_labels` argument. We'll focus on looking at the representations in\n",
    "* the output of the first convolutional layer after max pooling, stored in the model as `'pool'` (see Bonus Section 1 for details of the CNN architecture)\n",
    "* the 10-dimensional output of the fully connected layer, stored in the model as `'fc'`\n"
   ]
  },
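  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`get_hidden_activity()` records internal representations using PyTorch forward hooks. A function registered on a module with `register_forward_hook()` is called with `(module, input, output)` each time that module runs a forward pass. The standalone sketch below shows the same idea on a small hypothetical two-layer network, not the tutorial's CNN. It uses `named_children()` to label each layer. The names `toy_net`, `make_hook` and `activity` are illustrative only.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# A small hypothetical network with named layers (not the tutorial's CNN)\n",
    "toy_net = nn.Sequential()\n",
    "toy_net.add_module('fc', nn.Linear(4, 3))\n",
    "toy_net.add_module('out', nn.Linear(3, 1))\n",
    "\n",
    "activity = {}\n",
    "\n",
    "def make_hook(name):\n",
    "  # Return a hook that stores the module's output under the given layer name\n",
    "  def hook(module, inputs, output):\n",
    "    activity[name] = output.detach().numpy()\n",
    "  return hook\n",
    "\n",
    "# Register one forward hook per child module, keeping the handles\n",
    "handles = [m.register_forward_hook(make_hook(name))\n",
    "           for name, m in toy_net.named_children()]\n",
    "\n",
    "x = torch.randn(5, 4)  # 5 hypothetical stimuli with 4 features each\n",
    "toy_net(x)  # running a forward pass triggers the hooks\n",
    "\n",
    "# Remove the hooks so that later forward passes are not recorded\n",
    "for handle in handles:\n",
    "  handle.remove()\n",
    "\n",
    "print({name: a.shape for name, a in activity.items()})\n",
    "```"
   ]
  },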
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "both",
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:08.853895Z",
     "iopub.status.busy": "2021-05-31T20:50:08.853427Z",
     "iopub.status.idle": "2021-05-31T20:50:09.342954Z",
     "shell.execute_reply": "2021-05-31T20:50:09.342511Z"
    }
   },
   "outputs": [],
   "source": [
    "# Load mouse V1 data\n",
    "resp_v1, ori = load_data()\n",
    "\n",
    "# Extract model internal representations of each stimulus in the V1 data\n",
    "# construct grating stimuli for each orientation presented in the V1 data\n",
    "stimuli = torch.stack([grating(a.item()) for a in ori])\n",
    "layer_labels = ['pool', 'fc']\n",
    "resp_model = get_hidden_activity(net, stimuli, layer_labels)\n",
    "\n",
    "# Aggregate all responses into one dict\n",
    "resp_dict = {}\n",
    "resp_dict['V1 data'] = resp_v1\n",
    "for k, v in resp_model.items():\n",
    "  label = f\"model\\n'{k}' layer\"\n",
    "  resp_dict[label] = v"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Section 2: Quantitative comparisons of CNNs and neural activity\n",
    "\n",
    "Let's now analyze the internal representations of our deep CNN model of orientation discrimination and compare them to population responses in mouse primary visual cortex. \n",
    "\n",
    "In this section, we'll try to quantitatively compare CNN and primary visual cortex representations. In the next section, we will visualize their representations and get some intuition for their structure.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:09.348539Z",
     "iopub.status.busy": "2021-05-31T20:50:09.348004Z",
     "iopub.status.idle": "2021-05-31T20:50:09.374384Z",
     "shell.execute_reply": "2021-05-31T20:50:09.374890Z"
    }
   },
   "outputs": [],
   "source": [
    "#@title Video 2: Quantitative comparisons of CNNs and neural activity\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"2Jbk7jFBvbU\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtube.com/watch?v=\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's now quantify the similarities and differences between the population responses in mouse primary visual cortex and those in the different layers of our model.\n",
    "\n",
    "To do this, we'll use a technique called [**Representational Similarity Analysis**](https://www.frontiersin.org/articles/10.3389/neuro.06.004.2008/full?utm_source=FWEB&utm_medium=NBLOG&utm_campaign=ECO_10YA_top-research). The idea is to look at the similarity structure between representations of different stimuli. We can say that a brain area and a model use a similar representational scheme if stimuli that are represented (dis)similarly in the brain are represented (dis)similarly in the model as well."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Section 2.1: Representational dissimilarity matrix (RDM)\n",
    "\n",
    "\n",
    "To quantify this, we begin by computing the **representational dissimilarity matrix (RDM)** for the mouse V1 data and for each model layer. This matrix, which we'll call $\\mathbf{M}$, has one row and one column per stimulus. Its entry $M_{ij}$ is one minus the correlation coefficient between the population responses to stimuli $i$ and $j$. We can compute this efficiently using $z$-scored responses (see Bonus Section 3 for an explanation). In particular, the full matrix can be computed as:\n",
    "\\begin{gather}\n",
    "  \\mathbf{M} = 1 - \\frac{1}{N} \\mathbf{Z}\\mathbf{Z}^T\n",
    "\\end{gather}\n",
    "\n",
    "where $\\mathbf{Z}$ is the $S \\times N$ matrix of responses ($S$ stimuli, $N$ neurons or model units). Each row of $\\mathbf{Z}$ is $z$-scored across neurons, so each stimulus's population response has zero mean and unit variance.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Coding Exercise 2.1: Compute RDMs\n",
    "\n",
    "Complete the function `RDM()` for computing the RDM for a given set of population responses to each stimulus. Use the above formula in terms of $z$-scored population responses. You can use the helper function `zscore()` to compute the matrix of $z$-scored responses.\n",
    "\n",
    "The subsequent cell uses this function to plot the RDM of the population responses in the V1 data and in each layer of our model CNN.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:09.380824Z",
     "iopub.status.busy": "2021-05-31T20:50:09.380346Z",
     "iopub.status.idle": "2021-05-31T20:50:09.512071Z",
     "shell.execute_reply": "2021-05-31T20:50:09.511525Z"
    }
   },
   "outputs": [],
   "source": [
    "def RDM(resp):\n",
    "  \"\"\"Compute the representational dissimilarity matrix (RDM)\n",
    "\n",
    "  Args:\n",
    "    resp (ndarray): S x N matrix with population responses to\n",
    "      each stimulus in each row\n",
    "\n",
    "  Returns:\n",
    "    ndarray: S x S representational dissimilarity matrix\n",
    "  \"\"\"\n",
    "  #########################################################\n",
    "  ## TO DO for students: compute representational dissimilarity matrix\n",
    "  # Fill out function and remove\n",
    "  raise NotImplementedError(\"Student exercise: complete function RDM\")\n",
    "  #########################################################\n",
    "\n",
    "  # z-score responses to each stimulus\n",
    "  zresp = ...\n",
    "\n",
    "  # Compute RDM\n",
    "  RDM = ...\n",
    "\n",
    "  return RDM\n",
    "\n",
    "\n",
    "# Compute RDMs\n",
    "rdm_dict = {label: RDM(resp) for label, resp in resp_dict.items()}\n",
    "\n",
    "# Plot RDMs\n",
    "plot_multiple_rdm(rdm_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:09.516936Z",
     "iopub.status.busy": "2021-05-31T20:50:09.516407Z",
     "iopub.status.idle": "2021-05-31T20:50:10.252127Z",
     "shell.execute_reply": "2021-05-31T20:50:10.252512Z"
    }
   },
   "outputs": [],
   "source": [
    "# to_remove solution\n",
    "def RDM(resp):\n",
    "  \"\"\"Compute the representational dissimilarity matrix (RDM)\n",
    "\n",
    "  Args:\n",
    "    resp (ndarray): S x N matrix with population responses to\n",
    "      each stimulus in each row\n",
    "\n",
    "  Returns:\n",
    "    ndarray: S x S representational dissimilarity matrix\n",
    "  \"\"\"\n",
    "\n",
    "  # z-score responses to each stimulus\n",
    "  zresp = zscore(resp, axis=1)\n",
    "\n",
    "  # Compute RDM\n",
    "  RDM = 1 - (zresp @ zresp.T) / zresp.shape[1]\n",
    "\n",
    "  return RDM\n",
    "\n",
    "# Compute RDMs\n",
    "rdm_dict = {label: RDM(resp) for label, resp in resp_dict.items()}\n",
    "\n",
    "# Plot RDMs\n",
    "with plt.xkcd():\n",
    "  plot_multiple_rdm(rdm_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:10.256986Z",
     "iopub.status.busy": "2021-05-31T20:50:10.256447Z",
     "iopub.status.idle": "2021-05-31T20:50:10.282169Z",
     "shell.execute_reply": "2021-05-31T20:50:10.282545Z"
    }
   },
   "outputs": [],
   "source": [
    "#@title Video 3: Coding Exercise 2.1 solution discussion\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"otzR-KXDjus\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtube.com/watch?v=\" + video.id)\n",
    "video"
   ]
  },
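  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional sanity check of the RDM formula in Section 2.1, the cell below compares $1 - \\frac{1}{N}\\mathbf{Z}\\mathbf{Z}^T$ with one minus the explicit Pearson correlation matrix returned by `np.corrcoef()`. This is a minimal sketch on hypothetical random data (`resp_toy` is made up, not the V1 data). It assumes that $z$-scoring uses the population standard deviation (`ddof=0`, the default in `scipy.stats.zscore`), which is what makes the $\\frac{1}{N}$ normalization exact. With `ddof=1`, the correlation term would instead be scaled by a factor of $(N-1)/N$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy import stats\n",
    "\n",
    "# Hypothetical toy data: S = 5 stimuli x N = 40 neurons of random numbers\n",
    "rng_toy = np.random.default_rng(0)\n",
    "resp_toy = rng_toy.normal(size=(5, 40))\n",
    "\n",
    "# RDM from z-scored responses: M = 1 - Z Z^T / N\n",
    "z_toy = stats.zscore(resp_toy, axis=1)  # z-score each row (stimulus) across neurons, ddof=0\n",
    "rdm_fast = 1 - (z_toy @ z_toy.T) / z_toy.shape[1]\n",
    "\n",
    "# RDM from explicit Pearson correlations between rows (stimuli)\n",
    "rdm_corr = 1 - np.corrcoef(resp_toy)\n",
    "\n",
    "# The two matrices should agree up to floating-point error\n",
    "print(np.allclose(rdm_fast, rdm_corr))  # prints True"
   ]
  },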
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Section 2.2: Determining representational similarity\n",
    "\n",
    "\n",
    "To quantify how similar two representations are, we can simply correlate their dissimilarity matrices, again using the correlation coefficient. Because dissimilarity matrices are symmetric ($M_{ss'} = M_{s's}$), we should use only the entries on one side of the diagonal (e.g. the upper triangle) to avoid counting each pair of stimuli twice. We should also leave out the diagonal terms: these always equal 0, so they would match perfectly across any pair of RDMs and artificially inflate the correlation.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Coding Exercise 2.2: Correlate RDMs\n",
    "\n",
    "Complete the function `correlate_rdms()` below, which computes this correlation. The code for extracting the off-diagonal terms is provided.\n",
    "\n",
    "We will then use this function to compute the correlation between the RDM for each layer of our model CNN and the RDM of the V1 data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:10.288285Z",
     "iopub.status.busy": "2021-05-31T20:50:10.287011Z",
     "iopub.status.idle": "2021-05-31T20:50:10.288901Z",
     "shell.execute_reply": "2021-05-31T20:50:10.289290Z"
    }
   },
   "outputs": [],
   "source": [
    "def correlate_rdms(rdm1, rdm2):\n",
    "  \"\"\"Correlate off-diagonal elements of two RDM's\n",
    "\n",
    "  Args:\n",
    "    rdm1 (np.ndarray): S x S representational dissimilarity matrix\n",
    "    rdm2 (np.ndarray): S x S representational dissimilarity matrix to\n",
    "      correlate with rdm1\n",
    "\n",
    "  Returns:\n",
    "    float: correlation coefficient between the off-diagonal elements\n",
    "      of rdm1 and rdm2\n",
    "\n",
    "  \"\"\"\n",
    "\n",
    "  # Extract off-diagonal elements of each RDM\n",
    "  ioffdiag = np.triu_indices(rdm1.shape[0], k=1)  # indices of off-diagonal elements\n",
    "  rdm1_offdiag = rdm1[ioffdiag]\n",
    "  rdm2_offdiag = rdm2[ioffdiag]\n",
    "\n",
    "  #########################################################\n",
    "  ## TO DO for students: compute correlation coefficient\n",
    "  # Fill out function and remove\n",
    "  raise NotImplementedError(\"Student exercise: complete correlate rdms\")\n",
    "  #########################################################\n",
    "  corr_coef = np.corrcoef(..., ...)[0,1]\n",
    "\n",
    "  return corr_coef\n",
    "\n",
    "\n",
    "# Split RDMs into V1 responses and model responses\n",
    "rdm_model = rdm_dict.copy()\n",
    "rdm_v1 = rdm_model.pop('V1 data')\n",
    "\n",
    "# Correlate off-diagonal terms of dissimilarity matrices\n",
    "# Uncomment below to test your function\n",
    "# rdm_sim = {label: correlate_rdms(rdm_v1, rdm) for label, rdm in rdm_model.items()}\n",
    "# plot_rdm_rdm_correlations(rdm_sim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:10.294175Z",
     "iopub.status.busy": "2021-05-31T20:50:10.293714Z",
     "iopub.status.idle": "2021-05-31T20:50:10.539208Z",
     "shell.execute_reply": "2021-05-31T20:50:10.538732Z"
    }
   },
   "outputs": [],
   "source": [
    "# to_remove solution\n",
    "def correlate_rdms(rdm1, rdm2):\n",
    "  \"\"\"Correlate off-diagonal elements of two RDM's\n",
    "\n",
    "  Args:\n",
    "    rdm1 (np.ndarray): S x S representational dissimilarity matrix\n",
    "    rdm2 (np.ndarray): S x S representational dissimilarity matrix to\n",
    "      correlate with rdm1\n",
    "\n",
    "  Returns:\n",
    "    float: correlation coefficient between the off-diagonal elements\n",
    "      of rdm1 and rdm2\n",
    "\n",
    "  \"\"\"\n",
    "\n",
    "  # Extract off-diagonal elements of each RDM\n",
    "  ioffdiag = np.triu_indices(rdm1.shape[0], k=1)  # indices of off-diagonal elements\n",
    "  rdm1_offdiag = rdm1[ioffdiag]\n",
    "  rdm2_offdiag = rdm2[ioffdiag]\n",
    "\n",
    "  corr_coef = np.corrcoef(rdm1_offdiag, rdm2_offdiag)[0,1]\n",
    "\n",
    "  return corr_coef\n",
    "\n",
    "\n",
    "# Split RDMs into V1 responses and model responses\n",
    "rdm_model = rdm_dict.copy()\n",
    "rdm_v1 = rdm_model.pop('V1 data')\n",
    "\n",
    "# Correlate off-diagonal terms of dissimilarity matrices\n",
    "rdm_sim = {label: correlate_rdms(rdm_v1, rdm) for label, rdm in rdm_model.items()}\n",
    "with plt.xkcd():\n",
    "  plot_rdm_rdm_correlations(rdm_sim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "According to this metric, which layer's representations most resemble those in the data? Does this agree with your intuitions from exercise 3?"
   ]
  },
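  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the diagonal should be excluded, the sketch below correlates two RDMs that are unrelated by construction. The helper `random_rdm()` is hypothetical and only for illustration: it fills the off-diagonal entries of a symmetric matrix with independent random values and leaves zeros on the diagonal. When every entry is used, the shared zero diagonal should push the correlation well above 0 even though the matrices are unrelated; restricting to the upper triangle, as `correlate_rdms()` does, should give a correlation near 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def random_rdm(n_stim, rng):\n",
    "  \"\"\"Hypothetical helper: random symmetric matrix with zeros on the diagonal.\"\"\"\n",
    "  upper = np.triu(rng.uniform(0.5, 1.5, size=(n_stim, n_stim)), k=1)\n",
    "  return upper + upper.T  # symmetric; diagonal stays 0\n",
    "\n",
    "rng_toy = np.random.default_rng(1)\n",
    "n_stim = 20\n",
    "rdm_a = random_rdm(n_stim, rng_toy)\n",
    "rdm_b = random_rdm(n_stim, rng_toy)  # drawn independently of rdm_a\n",
    "\n",
    "# Correlation over all entries (diagonal zeros and duplicated pairs included)\n",
    "r_full = np.corrcoef(rdm_a.ravel(), rdm_b.ravel())[0, 1]\n",
    "\n",
    "# Correlation over the upper triangle only, excluding the diagonal\n",
    "iu = np.triu_indices(n_stim, k=1)\n",
    "r_offdiag = np.corrcoef(rdm_a[iu], rdm_b[iu])[0, 1]\n",
    "\n",
    "print(f'all entries: {r_full:.2f}, upper triangle only: {r_offdiag:.2f}')"
   ]
  },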
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Section 2.3: Further understanding RDMs\n",
    "\n",
    "To better understand how these correlations between RDMs arise, we can plot individual rows of each RDM. Row $s$ shows the dissimilarity between the response to stimulus $s$ and the response to each stimulus."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Coding Exercise 2.3: Plot rows of RDM\n",
    "\n",
    "\n",
    "Complete the `plot_rdm_rows()` function below, which plots rows of the model and data RDMs. We will then plot a few specified rows. Do these curves explain the correlation (or lack thereof) between RDMs that you saw in the previous exercise?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:10.546751Z",
     "iopub.status.busy": "2021-05-31T20:50:10.545809Z",
     "iopub.status.idle": "2021-05-31T20:50:10.547928Z",
     "shell.execute_reply": "2021-05-31T20:50:10.547513Z"
    }
   },
   "outputs": [],
   "source": [
    "def plot_rdm_rows(ori_list, rdm_dict, rdm_oris):\n",
    "  \"\"\"Plot the dissimilarity of response to each stimulus with response to one\n",
    "  specific stimulus\n",
    "\n",
    "  Args:\n",
    "    ori_list (list of float): plot dissimilarity with response to stimulus with\n",
    "      orientations closest to each value in this list\n",
    "    rdm_dict (dict): RDM's from which to extract dissimilarities\n",
    "    rdm_oris (np.ndarray): orientations corresponding to each row/column of RDMs\n",
    "    in rdm_dict\n",
    "\n",
    "  \"\"\"\n",
    "  n_col = len(ori_list)\n",
    "  f, axs = plt.subplots(1, n_col, figsize=(4 * n_col, 4), sharey=True)\n",
    "\n",
    "  # Get index of orientation closest to ori_plot\n",
    "  for ax, ori_plot in zip(axs, ori_list):\n",
    "    iori = np.argmin(np.abs(rdm_oris - ori_plot))\n",
    "\n",
    "    ######################################################################\n",
    "    # TODO: plot dissimilarity curves in each RDM and remove the error\n",
    "    raise NotImplementedError(\"Student exercise: complete plot_rdm_rows\")\n",
    "    ######################################################################\n",
    "\n",
    "    # Plot dissimilarity curves in each RDM\n",
    "    for label, rdm in rdm_dict.items():\n",
    "      ax.plot(..., ..., label=label)\n",
    "\n",
    "    # Draw vertical line at stimulus we are plotting dissimilarity w.r.t.\n",
    "    ax.axvline(rdm_oris[iori], color=\".7\", zorder=-1)\n",
    "\n",
    "    # Label axes\n",
    "    ax.set_title(f'Dissimilarity with response\\nto {ori_plot: .0f}$^o$ stimulus')\n",
    "    ax.set_xlabel('Stimulus orientation ($^o$)')\n",
    "\n",
    "  axs[0].set_ylabel('Dissimilarity')\n",
    "  axs[-1].legend(loc=\"upper left\", bbox_to_anchor=(1, 1))\n",
    "\n",
    "\n",
    "ori_list = [-75, -25, 25, 75]\n",
    "\n",
    "# Uncomment to test your function\n",
    "# plot_rdm_rows(ori_list, rdm_dict, ori)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:10.584781Z",
     "iopub.status.busy": "2021-05-31T20:50:10.567274Z",
     "iopub.status.idle": "2021-05-31T20:50:11.404218Z",
     "shell.execute_reply": "2021-05-31T20:50:11.404700Z"
    }
   },
   "outputs": [],
   "source": [
    "# to_remove solution\n",
    "def plot_rdm_rows(ori_list, rdm_dict, rdm_oris):\n",
    "  \"\"\"Plot the dissimilarity of response to each stimulus with response to one\n",
    "  specific stimulus\n",
    "\n",
    "  Args:\n",
    "    ori_list (list of float): plot dissimilarity with response to stimulus with\n",
    "      orientations closest to each value in this list\n",
    "    rdm_dict (dict): RDM's from which to extract dissimilarities\n",
    "    rdm_oris (np.ndarray): orientations corresponding to each row/column of RDMs\n",
    "    in rdm_dict\n",
    "\n",
    "  \"\"\"\n",
    "  n_col = len(ori_list)\n",
    "  f, axs = plt.subplots(1, n_col, figsize=(4 * n_col, 4), sharey=True)\n",
    "\n",
    "  # Get index of orientation closest to ori_plot\n",
    "  for ax, ori_plot in zip(axs, ori_list):\n",
    "    iori = np.argmin(np.abs(rdm_oris - ori_plot))\n",
    "\n",
    "    # Plot dissimilarity curves in each RDM\n",
    "    for label, rdm in rdm_dict.items():\n",
    "      ax.plot(rdm_oris, rdm[iori, :], label=label)\n",
    "\n",
    "    # Draw vertical line at stimulus we are plotting dissimilarity w.r.t.\n",
    "    ax.axvline(rdm_oris[iori], color=\".7\", zorder=-1)\n",
    "\n",
    "    # Label axes\n",
    "    ax.set_title(f'Dissimilarity with response\\nto {ori_plot: .0f}$^o$ stimulus')\n",
    "    ax.set_xlabel('Stimulus orientation ($^o$)')\n",
    "\n",
    "  axs[0].set_ylabel('Dissimilarity')\n",
    "  axs[-1].legend(loc=\"upper left\", bbox_to_anchor=(1, 1))\n",
    "\n",
    "\n",
    "ori_list = [-75, -25, 25, 75]\n",
    "with plt.xkcd():\n",
    "  plot_rdm_rows(ori_list, rdm_dict, ori)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Section 3: Qualitative comparisons of CNNs and neural activity\n",
    "\n",
    "To visualize the representations in the data and in each of these model layers, we'll use two classic techniques from systems neuroscience:\n",
    "\n",
    "1. **tuning curves**: plotting the response of single neurons (or units, in the case of the deep network) as a function of the stimulus orientation\n",
    "\n",
    "2. **dimensionality reduction**: plotting full population responses to each stimulus in two dimensions via dimensionality reduction. We'll use the non-linear dimensionality reduction technique t-SNE for this."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Section 3.1: Tuning curves"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below, we show example tuning curves for three randomly chosen V1 neurons and three randomly chosen units from each layer of the CNN we trained above. How are the single-neuron responses in the model similar to, or different from, those in the data? Try running this cell multiple times to get an idea of the properties shared by the tuning curves of neurons within each population.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:11.426147Z",
     "iopub.status.busy": "2021-05-31T20:50:11.424146Z",
     "iopub.status.idle": "2021-05-31T20:50:12.031736Z",
     "shell.execute_reply": "2021-05-31T20:50:12.032130Z"
    }
   },
   "outputs": [],
   "source": [
    "#@title\n",
    "#@markdown Execute this cell to visualize tuning curves\n",
    "\n",
    "fig, axs = plt.subplots(1, len(resp_dict), figsize=(len(resp_dict) * 6, 6))\n",
    "\n",
    "for i, (label, resp) in enumerate(resp_dict.items()):\n",
    "\n",
    "  ax = axs[i]\n",
    "  ax.set_title('%s responses' % label)\n",
    "\n",
    "  # Pick three random neurons whose tuning curves to plot\n",
    "  ineurons = np.random.choice(resp.shape[1], 3, replace=False)\n",
    "\n",
    "  # Plot tuning curves of ineurons\n",
    "  ax.plot(ori, resp[:, ineurons])\n",
    "\n",
    "  ax.set_xticks(np.linspace(-90, 90, 5))\n",
    "  ax.set_xlabel('stimulus orientation')\n",
    "  ax.set_ylabel('neural response')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Section 3.2: Dimensionality reduction of representations\n",
    "\n",
    "To look for informative structure, we can visualize dimensionality-reduced versions of the population responses in mouse primary visual cortex and in each CNN layer. Here, we first use PCA to reduce the responses to 20 dimensions (or fewer, if there are fewer than 20 neurons or units), and then use t-SNE to reduce them further to 2 dimensions. The PCA step is only there to make t-SNE run faster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:12.059107Z",
     "iopub.status.busy": "2021-05-31T20:50:12.040846Z",
     "iopub.status.idle": "2021-05-31T20:50:15.081618Z",
     "shell.execute_reply": "2021-05-31T20:50:15.082151Z"
    }
   },
   "outputs": [],
   "source": [
    "def plot_resp_lowd(resp_dict):\n",
    "    \"\"\"Plot a low-dimensional representation of each dataset in resp_dict.\"\"\"\n",
    "    n_col = len(resp_dict)\n",
    "    fig, axs = plt.subplots(1, n_col, figsize=(4.5 * len(resp_dict), 4.5))\n",
    "    for i, (label, resp) in enumerate(resp_dict.items()):\n",
    "\n",
    "      ax = axs[i]\n",
    "      ax.set_title('%s responses' % label)\n",
    "\n",
    "      # First do PCA to reduce dimensionality to 20 dimensions so that tSNE is faster\n",
    "      resp_lowd = PCA(n_components=min(20, resp.shape[1])).fit_transform(resp)\n",
    "\n",
    "      # Then do tSNE to reduce dimensionality to 2 dimensions\n",
    "      resp_lowd = TSNE(n_components=2).fit_transform(resp_lowd)\n",
    "\n",
    "      # Plot dimensionality-reduced population responses 'resp_lowd'\n",
    "      # on 2D axes, with each point colored by stimulus orientation\n",
    "      x, y = resp_lowd[:, 0], resp_lowd[:, 1]\n",
    "      pts = ax.scatter(x, y, c=ori, cmap='twilight', vmin=-90, vmax=90)\n",
    "      fig.colorbar(pts, ax=ax, ticks=np.linspace(-90, 90, 5), label='Stimulus orientation')\n",
    "\n",
    "      ax.set_xlabel('Dimension 1')\n",
    "      ax.set_ylabel('Dimension 2')\n",
    "      ax.set_xticks([])\n",
    "      ax.set_yticks([])\n",
    "\n",
    "with plt.xkcd():\n",
    "  plot_resp_lowd(resp_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Think! 3.2: Visualizing reduced dimensionality representations\n",
    "\n",
    "Interpret the figure above. Why do these representations look the way they do? Here are a few specific questions to think about:\n",
    "  * How are the population responses similar/different between the model and the data? Can you explain these population-level responses from the single neuron responses seen in the previous exercise, or vice-versa?\n",
    "  * How do the representations in the different layers of the model differ, and how does this relate to the orientation discrimination task the model was optimized for?\n",
    "  * Which layer of our deep network encoding model most closely resembles the V1 data?\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:15.086785Z",
     "iopub.status.busy": "2021-05-31T20:50:15.086251Z",
     "iopub.status.idle": "2021-05-31T20:50:15.091229Z",
     "shell.execute_reply": "2021-05-31T20:50:15.090870Z"
    }
   },
   "outputs": [],
   "source": [
    "# to_remove explanation\n",
    "\n",
    "\"\"\"\n",
    "The single unit activations of the 'pool' layer in the model have peaks at various\n",
    "orientations, but they appear to have more peaks than the tuning curves from the\n",
    "original data do. When we look at the population level responses we see that they\n",
    "are not quite as smooth across orientations in the t-SNE embedding, which is likely\n",
    "due to the fact that the 'pool' layer activations do not have localized responses\n",
    "to orientations like the neural data.\n",
    "\n",
    "The representations in the fully-connected 'fc' layer appear to be much more\n",
    "clustered than the 'pool' layer. Stimuli which correspond to the same choice are\n",
    "clustered together. It seems like the 'pool' layer is still working hard to\n",
    "represent information about the orientation of the stimulus (much like the V1\n",
    "population), whereas the 'fc' layer only cares about tilt category, representing\n",
    "all the stimuli with the same category in a similar way regardless of their different\n",
    "orientations.\n",
    "\n",
    "From this analysis, it appears that the 'pool' layer is more similar to the neural\n",
    "data at the population level.\n",
    "\n",
    "\"\"\";"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Summary\n",
    "\n",
    "In this notebook, we learned:\n",
    "* how to use deep learning to build a normative encoding model of the visual system\n",
    "* how to use RSA to evaluate how well the model's representations match those in the brain\n",
    "\n",
    "Our approach was to optimize a deep convolutional network to solve an orientation discrimination task. But note that many other approaches could have been taken.\n",
    "\n",
    "Firstly, there are many other \"normative\" ways to solve this orientation discrimination task. We could have used different neural network architectures, or even a completely different algorithm that didn't involve a neural network at all but instead used other kinds of image transformations (e.g. Fourier transforms). Neural network approaches, however, are special in that they explicitly use abstract distributed representations to compute, which feels a lot closer to the kinds of algorithms the brain uses. *Convolutional* neural networks in particular are well-suited for building normative models of the visual system.\n",
    "\n",
    "Secondly, our choice of visual task was mostly arbitrary. For example, we could have trained our network to directly estimate the orientation of the stimulus, rather than just discriminating between two classes of tilt. Or, we could have trained the network to perform a more naturalistic task, such as recognizing the rotation of an arbitrary image. Or we could try a task like object recognition. Is this something that mice compute in their visual cortex?\n",
    "\n",
    "Training on different tasks could lead to different representations of the oriented grating stimuli, which might match the observed V1 representations better or worse."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Bonus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bonus Section 1: Building CNN's with PyTorch\n",
    "\n",
    "Here we walk through building the different types of layers in a CNN using PyTorch, culminating in the CNN model used above."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Bonus Section 1.1: Fully connected layers\n",
    "\n",
    "In a fully connected layer, each unit computes a weighted sum over all the input units and applies a non-linear function to this weighted sum. You have used such layers many times already in parts 1 and 2. As you have already seen, these are implemented in PyTorch using the `nn.Linear` class.\n",
    "\n",
    "  See the next cell for code for constructing a deep network with one fully connected layer that will classify an input image as being tilted left or right. Specifically, its output is the predicted probability of the input image being tilted right. To ensure that its output is a probability (i.e. a number between 0 and 1), we use a sigmoid activation function to squash the output into this range (implemented with `torch.sigmoid()`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:15.096598Z",
     "iopub.status.busy": "2021-05-31T20:50:15.095903Z",
     "iopub.status.idle": "2021-05-31T20:50:15.097735Z",
     "shell.execute_reply": "2021-05-31T20:50:15.097341Z"
    }
   },
   "outputs": [],
   "source": [
    "class FC(nn.Module):\n",
    "  \"\"\"Deep network with one fully connected layer\n",
    "\n",
    "    Args:\n",
    "      h_in (int): height of input image, in pixels (i.e. number of rows)\n",
    "      w_in (int): width of input image, in pixels (i.e. number of columns)\n",
    "\n",
    "    Attributes:\n",
    "      fc (nn.Linear): weights and biases of fully connected layer\n",
    "      out (nn.Linear): weights and biases of output layer\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "  def __init__(self, h_in, w_in):\n",
    "    super().__init__()\n",
    "    self.dims = h_in * w_in  # dimensions of flattened input\n",
    "    self.fc = nn.Linear(self.dims, 10)  # flattened input image --> 10D representation\n",
    "    self.out = nn.Linear(10, 1)  # 10D representation --> scalar\n",
    "\n",
    "  def forward(self, x):\n",
    "    \"\"\"Classify grating stimulus as tilted right or left\n",
    "\n",
    "    Args:\n",
    "      x (torch.Tensor): p x 48 x 64 tensor with pixel grayscale values for\n",
    "          each of p stimulus images.\n",
    "\n",
    "    Returns:\n",
    "      torch.Tensor: p x 1 tensor with network outputs for each input provided\n",
    "          in x. Each output should be interpreted as the probability of the\n",
    "          corresponding stimulus being tilted right.\n",
    "\n",
    "    \"\"\"\n",
    "    x = x.view(-1, self.dims)  # flatten each input image into a vector\n",
    "    x = torch.relu(self.fc(x))  # output of fully connected layer\n",
    "    x = torch.sigmoid(self.out(x))  # network output\n",
    "    return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Bonus Section 1.2: Convolutional layers\n",
    "\n",
    "In a convolutional layer, each unit computes a weighted sum over a two-dimensional $K \\times K$ patch of inputs. As we saw in part 2, the units are arranged in **channels** (see figure below), whereby units in the same channel compute the same weighted sum over different parts of the input, using the weights of that channel's **convolutional filter (or kernel)**. The output of a convolutional layer is thus a three-dimensional tensor of shape $C^{out} \\times H \\times W$, where $C^{out}$ is the number of channels (i.e. the number of convolutional filters/kernels), and $H$ and $W$ are the height and width of the input (assuming, as in the code below, that the input is zero-padded so that the convolution preserves its spatial size).\n",
    "\n",
    "  <p align=\"center\">\n",
    "    <img src=\"https://github.com/NeuromatchAcademy/course-content/blob/master/tutorials/static/convnet.png?raw=true\" width=\"350\" />\n",
    "  </p>\n",
    "\n",
    "Such layers can be implemented in Python using the PyTorch class `nn.Conv2d`, as we saw in Tutorial 2 (documentation [here](https://pytorch.org/docs/master/generated/torch.nn.Conv2d.html)).\n",
    "  \n",
    "See the next cell for code incorporating a convolutional layer with 8 convolutional filters of size 5 $\\times$ 5 into our above fully connected network. Note that we have to flatten the multi-channel output in order to pass it on to the fully connected layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:15.105641Z",
     "iopub.status.busy": "2021-05-31T20:50:15.104720Z",
     "iopub.status.idle": "2021-05-31T20:50:15.106360Z",
     "shell.execute_reply": "2021-05-31T20:50:15.106736Z"
    }
   },
   "outputs": [],
   "source": [
    "class ConvFC(nn.Module):\n",
    "  \"\"\"Deep network with one convolutional layer and one fully connected layer\n",
    "\n",
    "  Args:\n",
    "    h_in (int): height of input image, in pixels (i.e. number of rows)\n",
    "    w_in (int): width of input image, in pixels (i.e. number of columns)\n",
    "\n",
    "  Attributes:\n",
    "    conv (nn.Conv2d): filter weights of convolutional layer\n",
    "    dims (tuple of ints): dimensions of output from conv layer\n",
    "    fc (nn.Linear): weights and biases of fully connected layer\n",
    "    out (nn.Linear): weights and biases of output layer\n",
    "\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self, h_in, w_in):\n",
    "    super().__init__()\n",
    "    C_in = 1  # input stimuli have only 1 input channel\n",
    "    C_out = 8  # number of output channels (i.e. of convolutional kernels to convolve the input with)\n",
    "    K = 5  # size of each convolutional kernel (should be odd number for the padding to work as expected)\n",
    "    self.conv = nn.Conv2d(C_in, C_out, kernel_size=K, padding=K//2)  # add padding to ensure that each channel has same dimensionality as input\n",
    "    self.dims = (C_out, h_in, w_in)  # dimensions of conv layer output (channels x height x width)\n",
    "    self.fc = nn.Linear(np.prod(self.dims), 10)  # flattened conv output --> 10D representation\n",
    "    self.out = nn.Linear(10, 1)  # 10D representation --> scalar\n",
    "\n",
    "  def forward(self, x):\n",
    "    \"\"\"Classify grating stimulus as tilted right or left\n",
    "\n",
    "    Args:\n",
    "      x (torch.Tensor): p x 48 x 64 tensor with pixel grayscale values for\n",
    "          each of p stimulus images.\n",
    "\n",
    "    Returns:\n",
    "      torch.Tensor: p x 1 tensor with network outputs for each input provided\n",
    "          in x. Each output should be interpreted as the probability of the\n",
    "          corresponding stimulus being tilted right.\n",
    "\n",
    "    \"\"\"\n",
    "    x = x.unsqueeze(1)  # p x 1 x 48 x 64, add a singleton dimension for the single stimulus channel\n",
    "    x = torch.relu(self.conv(x))  # output of convolutional layer\n",
    "    x = x.view(-1, np.prod(self.dims))  # flatten convolutional layer outputs into a vector\n",
    "    x = torch.relu(self.fc(x))  # output of fully connected layer\n",
    "    x = torch.sigmoid(self.out(x))  # network output\n",
    "    return x"
   ]
  },
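  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check the shape bookkeeping in `ConvFC`, the sketch below passes a batch of blank images through standalone `nn.Conv2d` layers. It assumes 48 x 64 grayscale inputs, as in the docstrings above, and the layer names (`conv_odd`, `conv_even`) are only for illustration. With an odd kernel size $K$ and `padding=K//2`, the output keeps the input's height and width, so each flattened output has $C^{out} \\times H \\times W$ entries. With an even kernel, the same padding rule adds one extra row and column, which is why the comment in `ConvFC` asks for an odd $K$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# Assumed batch of p = 3 blank grayscale 48 x 64 images, shaped p x C_in x H x W\n",
    "x_demo = torch.zeros(3, 1, 48, 64)\n",
    "\n",
    "# Odd kernel (K = 5) with padding K // 2 = 2: spatial size is preserved\n",
    "conv_odd = nn.Conv2d(1, 8, kernel_size=5, padding=5 // 2)\n",
    "print(conv_odd(x_demo).shape)  # torch.Size([3, 8, 48, 64])\n",
    "print(conv_odd(x_demo)[0].numel())  # 8 * 48 * 64 = 24576 inputs to the fc layer\n",
    "\n",
    "# Even kernel (K = 4) with padding K // 2 = 2: one extra row and column\n",
    "conv_even = nn.Conv2d(1, 8, kernel_size=4, padding=4 // 2)\n",
    "print(conv_even(x_demo).shape)  # torch.Size([3, 8, 49, 65])"
   ]
  },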
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Bonus Section 1.3: Max pooling layers\n",
    "\n",
    "In a max pooling layer, each unit computes the maximum over a small two-dimensional $K^{pool} \\times K^{pool}$ patch of inputs. Given a multi-channel input of dimensions $C \\times H \\times W$, the output of a max pooling layer has dimensions $C \\times H^{out} \\times W^{out}$, where:\n",
    "\\begin{align}\n",
    "  H^{out} &= \\left\\lfloor \\frac{H}{K^{pool}} \\right\\rfloor\\\\\n",
    "  W^{out} &= \\left\\lfloor \\frac{W}{K^{pool}} \\right\\rfloor\n",
    "\\end{align}\n",
    "where $\\lfloor\\cdot\\rfloor$ denotes rounding down to the nearest integer (i.e. floor division `//` in Python).\n",
    "\n",
    "Max pooling layers can be implemented with the PyTorch `nn.MaxPool2d` class, whose only required argument is the size $K^{pool}$ of the pooling patch (by default, the stride equals this size, so the patches do not overlap). See the `PoolConvFC` class below for an example, which builds on the previous example by adding a max pooling layer just after the convolutional layer. Note again that we need to calculate the dimensions of its output in order to set the dimensions of the subsequent fully connected layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-31T20:50:15.114548Z",
     "iopub.status.busy": "2021-05-31T20:50:15.113340Z",
     "iopub.status.idle": "2021-05-31T20:50:15.115097Z",
     "shell.execute_reply": "2021-05-31T20:50:15.115486Z"
    }
   },
   "outputs": [],
   "source": [
    "class PoolConvFC(nn.Module):\n",
    "  \"\"\"Deep network with one convolutional layer followed by a max pooling layer\n",
    "  and one fully connected layer\n",
    "\n",
    "  Args:\n",
    "    h_in (int): height of input image, in pixels (i.e. number of rows)\n",
    "    w_in (int): width of input image, in pixels (i.e. number of columns)\n",
    "\n",
    "  Attributes:\n",
    "    conv (nn.Conv2d): filter weights of convolutional layer\n",
    "    pool (nn.MaxPool2d): max pooling layer\n",
    "    dims (tuple of ints): dimensions of output from pool layer\n",
    "    fc (nn.Linear): weights and biases of fully connected layer\n",
    "    out (nn.Linear): weights and biases of output layer\n",
    "\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self, h_in, w_in):\n",
    "    super().__init__()\n",
    "    C_in = 1  # input stimuli have only 1 input channel\n",
    "    C_out = 8  # number of output channels (i.e. of convolutional kernels to convolve the input with)\n",
    "    K = 5  # size of each convolutional kernel\n",
    "    Kpool = 2  # size of patches over which to pool\n",
    "    self.conv = nn.Conv2d(C_in, C_out, kernel_size=K, padding=K//2)  # add padding to ensure that each channel has same dimensionality as input\n",
    "    self.pool = nn.MaxPool2d(Kpool)\n",
    "    self.dims = (C_out, h_in // Kpool, w_in // Kpool)  # dimensions of pool layer output\n",
    "    self.fc = nn.Linear(np.prod(self.dims), 10)  # flattened pool output --> 10D representation\n",
    "    self.out = nn.Linear(10, 1)  # 10D representation --> scalar\n",
    "\n",
    "  def forward(self, x):\n",
    "    \"\"\"Classify grating stimulus as tilted right or left\n",
    "\n",
    "    Args:\n",
    "      x (torch.Tensor): p x 48 x 64 tensor with pixel grayscale values for\n",
    "          each of p stimulus images.\n",
    "\n",
    "    Returns:\n",
    "      torch.Tensor: p x 1 tensor with network outputs for each input provided\n",
    "          in x. Each output should be interpreted as the probability of the\n",
    "          corresponding stimulus being tilted right.\n",
    "\n",
    "    \"\"\"\n",
    "    x = x.unsqueeze(1)  # p x 1 x 48 x 64, add a singleton dimension for the single stimulus channel\n",
    "    x = torch.relu(self.conv(x))  # output of convolutional layer\n",
    "    x = self.pool(x)  # output of pooling layer\n",
    "    x = x.view(-1, np.prod(self.dims))  # flatten pooling layer outputs into a vector\n",
    "    x = torch.relu(self.fc(x))  # output of fully connected layer\n",
    "    x = torch.sigmoid(self.out(x))  # network output\n",
    "    return x"
   ]
  },
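  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see where `self.dims` comes from, this minimal sketch traces the tensor shapes through the same layer types and hyperparameters that `PoolConvFC` uses, applied to a hypothetical mini-batch of 3 random $48 \\times 64$ images. With `padding=K//2` the convolution preserves the $48 \\times 64$ size, and the pooling layer halves each spatial dimension, so the flattened vector fed to the fully connected layer has $8 \\times 24 \\times 32 = 6144$ entries:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "h_in, w_in = 48, 64  # stimulus size used in this tutorial\n",
    "C_out, K, Kpool = 8, 5, 2  # same hyperparameters as PoolConvFC\n",
    "conv = nn.Conv2d(1, C_out, kernel_size=K, padding=K//2)\n",
    "pool = nn.MaxPool2d(Kpool)\n",
    "\n",
    "x = torch.randn(3, h_in, w_in).unsqueeze(1)  # 3 x 1 x 48 x 64\n",
    "h = torch.relu(conv(x))  # 3 x 8 x 48 x 64, padding keeps the spatial size\n",
    "h = pool(h)  # 3 x 8 x 24 x 32\n",
    "print(h.shape)\n",
    "print(np.prod(h.shape[1:]))  # 6144, the input size of the fully connected layer\n",
    "```"
   ]
  },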
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This pooling layer completes the CNN model trained above to perform orientation discrimination. We can think of this architecture as having two primary layers:\n",
    "1. a convolutional + pooling layer\n",
    "2. a fully connected layer\n",
    "\n",
    "We group together the convolution and pooling layers into one, as they really form one full unit of convolutional processing, where each patch of the image is passed through a convolutional filter and pooled with neighboring patches. It is standar practice to follow up any convolutional layer with a pooling layer, so they are generally treated as a single block of processing."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bonus Section 2: Orientation discrimination as a binary classification problem\n",
    "\n",
    "What loss function should we minimize to optimize orientation discrimination performance? We first note that the orientation discrimination task is a **binary classification problem**, where the goal is to classify a given stimulus into one of two classes: being tilted left or being tilted right. \n",
    "\n",
    "Our goal is thus to output a high probability of the stimulus being tilted right (i.e. large $p$) whenever the stimulus is tilted right, and a high probability of the stimulus being tilted left (i.e. large $1-p \\Leftrightarrow$ small $p$) whenever the stimulus is tilted left.\n",
    "\n",
    "Let $\\tilde{y}^{(n)}$ be the label of the $n$th stimulus in the mini-batch, indicating its true tilt:\n",
    "\\begin{equation}\n",
    "  \\tilde{y}^{(n)} =\n",
    "  \\begin{cases}\n",
    "    1 &\\text{if stimulus }n\\text{ is tilted right} \\\\\n",
    "    0 &\\text{if stimulus }n\\text{ is tilted left}\n",
    "  \\end{cases}\n",
    "\\end{equation}\n",
    "Let $p^{(n)}$ be the predicted probability of that stimulus being tilted right assigned by our network. Note that that $1-p^{(n)}$ is the predicted probability of that stimulus being tilted left. We'd now like to modify the parameters so as to maximize the predicted probability of the true class $\\tilde{y}^{(n)}$. One way to formalize this is as maximizing the *log* probability\n",
    "\\begin{align}\n",
    "  \\log \\left( \\text{predicted probability of stimulus } n \\text{ being of class } \\tilde{y}^{(n)}\\right) &= \n",
    "  \\begin{cases}\n",
    "    \\log p^{(n)} &\\text{if }\\tilde{y}^{(n)} = 1 \\\\\n",
    "    \\log (1 - p^{(n)}) &\\text{if }\\tilde{y}^{(n)} = 0\n",
    "  \\end{cases}\n",
    "  \\\\\n",
    "  &= \\tilde{y}^{(n)} \\log p^{(n)} + (1 - \\tilde{y}^{(n)})\\log(1 - p^{(n)})\n",
    "\\end{align}\n",
    "You should recognize this expression as the log likelihood of the Bernoulli distribution under the predicted probability $p^{(n)}$. This is the same quantity that is maximized in logistic regression, where the predicted probability $p^{(n)}$ is just a simple linear sum of its inputs (rather than a complicated non-linear operation, like in the deep networks used here).\n",
    "\n",
    "To turn this into a loss function, we simply multiply it by -1, resulting in the so-called **binary cross-entropy**, or **negative log likelihood**. Summing over $P$ samples in a batch, the binary cross entropy loss is given by\n",
    "\\begin{equation}\n",
    "  L = -\\sum_{n=1}^P \\tilde{y}^{(n)} \\log p^{(n)} + (1 - \\tilde{y}^{(n)})\\log(1 - p^{(n)})\n",
    "\\end{equation}\n",
    "The binary cross-entropy loss can be implemented in PyTorch using the `nn.BCELoss()` loss function (cf. [documentation](https://pytorch.org/docs/master/generated/torch.nn.BCELoss.html)). \n",
    "\n",
    "Feel free to check out the code used to optimize the CNN in the `train()` function defined in the hidden cell of helper functions at the top of the notebook. Because the CNN's used here have lots of parameters, we have to use two tricks that we didn't use in the previous parts of this tutorial:\n",
    "1. We have to use *stochastic* gradient descent (SGD), rather than just gradient descent (GD).\n",
    "2. We have to use [momentum](https://distill.pub/2017/momentum/) in our SGD updates. This is easily incorporated into our PyTorch implementation by just setting the `momentum` argument of the built-in `optim.SGD` optimizer."
   ]
  },
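  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two tricks above can be sketched on a toy problem: each update uses a random mini-batch (SGD) and a velocity term (momentum). This is a minimal sketch only; the network, data, learning rate, momentum value, and batch size are hypothetical choices for illustration, not the settings used by the tutorial's `train()` helper:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "torch.manual_seed(0)\n",
    "X = torch.randn(200, 10)  # hypothetical inputs\n",
    "y = (X[:, 0] > 0).float().unsqueeze(1)  # label depends only on the first feature\n",
    "net = nn.Sequential(nn.Linear(10, 1), nn.Sigmoid())  # outputs a probability\n",
    "loss_fn = nn.BCELoss()\n",
    "optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)\n",
    "\n",
    "with torch.no_grad():\n",
    "  loss_before = loss_fn(net(X), y).item()\n",
    "\n",
    "batch_size = 20\n",
    "for epoch in range(20):\n",
    "  perm = torch.randperm(len(X))  # reshuffle the data every epoch\n",
    "  for i in range(0, len(X), batch_size):\n",
    "    idx = perm[i:i + batch_size]  # indices of one random mini-batch\n",
    "    optimizer.zero_grad()\n",
    "    loss = loss_fn(net(X[idx]), y[idx])\n",
    "    loss.backward()\n",
    "    optimizer.step()  # step along a velocity: an exponentially decaying sum of past gradients\n",
    "\n",
    "with torch.no_grad():\n",
    "  loss_after = loss_fn(net(X), y).item()\n",
    "print(loss_before, loss_after)\n",
    "```"
   ]
  },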
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bonus Section 3: RDM Z-Score Explanation\n",
    "\n",
    "If $r^{(s)}_i$ is the response of the $i$th neuron to the $s$th stimulus, then\n",
    "\\begin{gather}\n",
    "  M_{ss'} = 1 - \\frac{\\text{Cov}\\left[ r_i^{(s)}, r_i^{(s')} \\right]}{\\sqrt{\\text{Var}\\left[ r_i^{(s)} \\right] \\text{Var}\\left[ r_i^{(s')} \\right]}} = 1 - \\frac{\\sum_{i=1}^N (r_i^{(s)} - \\bar{r}^{(s)})(r_i^{(s')} - \\bar{r}^{(s')}) }{\\sqrt{\\sum_{i=1}^N \\left( r_i^{(s)} - \\bar{r}^{(s)} \\right)^2 \\sum_{i=1}^N \\left( r_i^{(s')} - \\bar{r}^{(s')} \\right)^2 }} \\\\\n",
    "  \\bar{r}^{(s)} = \\frac{1}{N} \\sum_{i=1}^N r_i^{(s)}\n",
    "\\end{gather}\n",
    "This can be computed efficiently by using the $z$-scored responses\n",
    "\\begin{equation}\n",
    "  z_i^{(s)} = \\frac{r_i^{(s)} - \\bar{r}^{(s)}}{\\sqrt{\\frac{1}{N}\\sum_{i=1}^N \\left( r_i^{(s)} - \\bar{r}^{(s)} \\right)^2}} \\Rightarrow M_{ss'} = 1 - \\frac{1}{N}\\sum_{i=1}^N z_i^{(s)}z_i^{(s')}\n",
    "\\end{equation}\n",
    "such that the full matrix can be computed through the matrix multiplication\n",
    "\\begin{gather}\n",
    "  \\mathbf{M} = 1 - \\frac{1}{N} \\mathbf{ZZ}^T \\\\\n",
    "  \\mathbf{Z} = \n",
    "  \\begin{bmatrix}\n",
    "    z_1^{(1)} & z_2^{(1)} & \\ldots & z_N^{(1)} \\\\\n",
    "    z_1^{(2)} & z_2^{(2)} & \\ldots & z_N^{(2)} \\\\\n",
    "    \\vdots & \\vdots & \\ddots & \\vdots \\\\\n",
    "    z_1^{(S)} & z_2^{(S)} & \\ldots & z_N^{(S)}\n",
    "  \\end{bmatrix}\n",
    "\\end{gather}\n",
    "\n",
    "\n",
    "where $S$ is the total number of stimuli. Note that $\\mathbf{Z}$ is an $S \\times N$ matrix, and $\\mathbf{M}$ is an $S \\times S$ matrix."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "W2D1_Tutorial3",
   "provenance": [],
   "toc_visible": true
  },
  "kernel": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
